import torch.multiprocessing as mp

# Pick the process start method once, at import time, before any DataLoader
# worker processes exist. 'spawn' is used because CUDA cannot be used in
# forked subprocesses; 'forkserver' would be an acceptable alternative.
# An already-chosen method is left untouched.
_START_METHOD = 'spawn'
if not mp.get_start_method(allow_none=True):
    mp.set_start_method(_START_METHOD, force=True)

import argparse
import os
import time
from IPython import embed

import matplotlib.pyplot as plt
import torch
from torchvision.transforms import transforms

import numpy as np
import common_args
import random
from dataset import Dataset as Dataset_old
from dataset_new import Dataset, ImageDataset, Dataset_wt, Dataset_pred_reward, Dataset_pred_reward_opt_a
from net import Transformer, ImageTransformer, Transformer_new, Transformer_new_opt_a
from utils import (
    build_bandit_data_filename,
    build_bandit_model_filename,
    build_linear_bandit_data_filename,
    build_linear_bandit_model_filename,
    build_darkroom_data_filename,
    build_darkroom_model_filename,
    build_miniworld_data_filename,
    build_miniworld_model_filename,
    worker_init_fn,
)
from tqdm import tqdm, trange
import logging
from torch.utils.tensorboard import SummaryWriter

from torch.optim.lr_scheduler import StepLR

import ipdb

# Run on the current CUDA device when one is visible, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


if __name__ == '__main__':
    # Output locations: loss curves / logs go under figs/loss and checkpoints
    # go to the (machine-specific, hard-coded) external models directory.
    # makedirs(exist_ok=True) already tolerates existing directories, so no
    # separate existence check is needed (that check was also racy).
    os.makedirs('figs/loss', exist_ok=True)
    os.makedirs('/media/external/subho/DPT/models', exist_ok=True)

    # Dataset / model / training options are registered by common_args;
    # only the seed is defined locally.
    parser = argparse.ArgumentParser()
    common_args.add_dataset_args(parser)
    common_args.add_model_args(parser)
    common_args.add_train_args(parser)

    parser.add_argument(
        '--seed', type=int, default=0,
        help='Random seed. -1 seeds the RNGs with 0 but is kept as -1 in '
             'the model config used for file names.')

    args = vars(parser.parse_args())
    print("Args: ", args)

    # Pull the parsed CLI options into the module-level names used by the
    # rest of the script.
    env, n_envs, n_hists, n_samples = (
        args['env'], args['envs'], args['hists'], args['samples'])
    horizon, dim = args['H'], args['dim']
    # State and action sizes start out equal to `dim`; later env-specific
    # code may override them.
    state_dim = action_dim = dim
    n_embd, n_head, n_layer = args['embd'], args['head'], args['layer']
    lr, shuffle, dropout = args['lr'], args['shuffle'], args['dropout']
    var, cov = args['var'], args['cov']
    num_epochs, seed = args['num_epochs'], args['seed']
    lin_d, batch_size = args['lin_d'], args['batch_size']

    # A seed of -1 still seeds every RNG, using 0 instead.
    tmp_seed = 0 if seed == -1 else seed

    torch.manual_seed(tmp_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(tmp_seed)
        torch.cuda.manual_seed_all(tmp_seed)
    # Prefer reproducible cuDNN kernels over autotuned (benchmark) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(tmp_seed)
    random.seed(tmp_seed)

    if shuffle and env == 'linear_bandit':
        raise Exception("Are you sure you want to shuffle on the linear bandit? Data collected from an adaptive algorithm in a stochastic setting can bias the learner if shuffled.")

    # Keys shared by the data-file and model-file name builders. Each env
    # branch below appends its own keys; insertion order is preserved in case
    # the builders depend on it.
    dataset_config = {
        'n_hists': n_hists,
        'n_samples': n_samples,
        'horizon': horizon,
        'dim': dim,
    }
    model_config = {
        'shuffle': shuffle,
        'lr': lr,
        'dropout': dropout,
        'n_embd': n_embd,
        'n_layer': n_layer,
        'n_head': n_head,
        'n_envs': n_envs,
        'n_hists': n_hists,
        'n_samples': n_samples,
        'horizon': horizon,
        'dim': dim,
        'seed': seed,
    }

    # Every linear-bandit variant builds its data/model file names the same
    # way (with lin_d), so they share a single branch.
    linear_bandit_envs = (
        'linear_bandit',
        'linear_bandit_exclude',
        'linear_bandit_new_train',
        'linear_bandit_train_lookahead',
        'linear_bandit_train_lookahead_wt',
        'linear_bandit_train_lookahead_mix',
        'linear_bandit_train_original',
        'linear_bandit_train_original_emp_opt',
        'linear_bandit_train_lookahead_pred_reward',
        'linear_bandit_train_lookahead_pred_reward_opt_a',
        'linear_bandit_train_AD',
    )

    # mode=0 / mode=1 select the train / test data files respectively.
    if env in ('bandit', 'bandit_thompson'):
        state_dim = 1

        # The two bandit envs differ only in the data 'type' key.
        bandit_type = 'uniform' if env == 'bandit' else 'bernoulli'
        dataset_config.update({'var': var, 'cov': cov, 'type': bandit_type})
        path_train = build_bandit_data_filename(
            env, n_envs, dataset_config, mode=0)
        path_test = build_bandit_data_filename(
            env, n_envs, dataset_config, mode=1)

        model_config.update({'var': var, 'cov': cov})
        filename = build_bandit_model_filename(env, model_config)

    elif env in linear_bandit_envs:
        state_dim = 1

        dataset_config.update({'lin_d': lin_d, 'var': var, 'cov': cov})
        path_train = build_linear_bandit_data_filename(
            env, n_envs, dataset_config, mode=0)
        path_test = build_linear_bandit_data_filename(
            env, n_envs, dataset_config, mode=1)

        model_config.update({'lin_d': lin_d, 'var': var, 'cov': cov})
        filename = build_linear_bandit_model_filename(env, model_config)

    elif env.startswith('darkroom'):
        state_dim = 2
        action_dim = 5

        dataset_config.update({'rollin_type': 'uniform'})
        path_train = build_darkroom_data_filename(
            env, n_envs, dataset_config, mode=0)
        path_test = build_darkroom_data_filename(
            env, n_envs, dataset_config, mode=1)

        filename = build_darkroom_model_filename(env, model_config)

    elif env == 'miniworld':
        state_dim = 2   # direction vector is 2D, no position included
        action_dim = 4

        dataset_config.update({'rollin_type': 'uniform'})

        # Miniworld data is addressed in chunks of `increment` env ids,
        # each identified by an inclusive [start, end] id range.
        increment = 5000
        starts = np.arange(0, n_envs, increment)
        ends = starts + increment - 1

        paths_train = []
        paths_test = []
        for start_env_id, end_env_id in zip(starts, ends):
            path_train = build_miniworld_data_filename(
                env, start_env_id, end_env_id, dataset_config, mode=0)
            path_test = build_miniworld_data_filename(
                env, start_env_id, end_env_id, dataset_config, mode=1)
            paths_train.append(path_train)
            paths_test.append(path_test)

        filename = build_miniworld_model_filename(env, model_config)
        print(f"Generate filename: {filename}")

    else:
        raise NotImplementedError(f"Unsupported env: {env!r}")

    # Hyper-parameters handed to both the model and the dataset constructors.
    config = dict(
        horizon=horizon,
        state_dim=state_dim,
        action_dim=action_dim,
        n_layer=n_layer,
        n_embd=n_embd,
        n_head=n_head,
        shuffle=shuffle,
        dropout=dropout,
        test=False,
        store_gpu=True,
    )

    if env == 'miniworld':
        # Image observations; store_gpu is turned off for this env
        # (presumably keeps the image data off the GPU — confirm in dataset).
        config.update({'image_size': 25, 'store_gpu': False})
        model_cls = ImageTransformer
    else:
        # Reward-predicting variants use their own architectures; every other
        # env (including linear_bandit_train_AD) uses the original
        # Transformer with an action head.
        model_cls = {
            'linear_bandit_train_lookahead_pred_reward': Transformer_new,
            'linear_bandit_train_lookahead_pred_reward_opt_a': Transformer_new_opt_a,
            'darkroom_heldout_lookahead_pred_reward': Transformer_new,
        }.get(env, Transformer)

    model = model_cls(config).to(device)

    # Keyword arguments for both DataLoaders; the miniworld setup may extend
    # or override them.
    params = dict(batch_size=batch_size, shuffle=True)
    # Batch sizes noted by the author: 1536 (1024 + 512), 1280 (1024 + 256),
    # 1312 (1024 + 256 + 32), 5408 (4096 + 1024 + 256 + 32).

    # Start every run with an empty text log (truncates a previous run's log).
    log_filename = f'figs/loss/{filename}_logs.txt'
    with open(log_filename, 'w'):
        pass
    def printw(string):
        """Echo ``string`` to the console and append it to ``log_filename``.

        Drop-in replacement for ``print`` for run logs. The log file is
        reopened in append mode and closed again on every call, so each line
        is written out immediately rather than sitting in a buffer.
        """
        print(string)  # console copy
        with open(log_filename, mode='a') as log_handle:
            print(string, file=log_handle)  # persistent copy

    # Per-run directory that receives the TensorBoard event files.
    # makedirs(exist_ok=True) already handles an existing directory, so the
    # former exists() pre-check was redundant (and racy).
    output_dir = f'figs/loss/{filename}/'
    os.makedirs(output_dir, exist_ok=True)

    tb_writer = SummaryWriter(log_dir=output_dir)

    if env == 'miniworld':
        # RGB frames are converted to tensors and normalized with the
        # standard ImageNet channel mean/std.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        # Miniworld overrides the loader settings: 16 persistent worker
        # processes, pinned memory and a fixed batch size of 64.
        params.update({
            'num_workers': 16,
            'prefetch_factor': 2,
            'persistent_workers': True,
            'pin_memory': True,
            'batch_size': 64,
            'worker_init_fn': worker_init_fn,
        })

        printw("Loading miniworld data...")
        train_dataset = ImageDataset(paths_train, config, transform)
        test_dataset = ImageDataset(paths_test, config, transform)
        printw("Done loading miniworld data")
    else:
        # Env-specific dataset classes; anything not listed uses Dataset.
        # linear_bandit_train_AD reuses the pred-reward dataset and only its
        # action targets are used.
        dataset_cls = {
            'linear_bandit_train_lookahead_wt': Dataset_wt,
            'linear_bandit_train_lookahead_pred_reward': Dataset_pred_reward,
            'linear_bandit_train_lookahead_pred_reward_opt_a': Dataset_pred_reward_opt_a,
            'linear_bandit_train_AD': Dataset_pred_reward,
            'darkroom_heldout': Dataset_old,
            'darkroom_heldout_lookahead_pred_reward': Dataset_pred_reward,
        }.get(env, Dataset)

        train_dataset = dataset_cls(path_train, config)
        test_dataset = dataset_cls(path_test, config)

    # Both splits share the same loader options.
    train_loader = torch.utils.data.DataLoader(train_dataset, **params)
    test_loader = torch.utils.data.DataLoader(test_dataset, **params)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    # Multiply the learning rate by 0.99 every 500 scheduler steps.
    scheduler = StepLR(optimizer, step_size=500, gamma=0.99)

    # Losses are summed (not averaged) over the batch; per-epoch numbers are
    # normalized when they are accumulated.
    loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')      # action targets
    loss_fn_reward = torch.nn.MSELoss(reduction='sum')        # reward targets

    # Per-epoch loss history used for logging and plotting.
    test_loss = []
    train_loss = []

    printw(f"Num train batches: {len(train_loader)}")
    printw(f"Num test batches: {len(test_loader)}")

    
    
    def reward_at_actions(pred_rewards, true_actions):
        """Select each row's predicted reward for the action marked in ``true_actions``.

        Args:
            pred_rewards: reward output of the model; reshaped here to
                ``(-1, action_dim)`` so column ``a`` is the prediction for
                action ``a``.
            true_actions: one-hot action targets already shaped
                ``(-1, action_dim)``; their argmax picks the supervised action.

        Returns:
            Tensor of shape ``(rows, 1)`` holding the selected predictions.
        """
        pred_rewards = pred_rewards.reshape(-1, action_dim)
        chosen = torch.argmax(true_actions, dim=1)
        # Build the row index on the predictions' device so the gather never
        # mixes CPU and CUDA index tensors.
        rows = torch.arange(pred_rewards.shape[0], device=pred_rewards.device)
        return pred_rewards[rows, chosen].unsqueeze(1)

    def compute_loss(batch, training):
        """Return the summed loss of ``model`` on one batch already on ``device``.

        The objective depends on the module-level ``env``:

        * ``linear_bandit_train_lookahead_pred_reward`` and
          ``darkroom_heldout_lookahead_pred_reward``: MSE between the
          predicted reward of the target action and ``context_pred_rewards``
          (the action head is not supervised).
        * ``linear_bandit_train_lookahead_pred_reward_opt_a``: cross-entropy
          on ``context_opt_actions`` + the reward MSE above + cross-entropy
          on ``context_pred_opt_a``.
        * ``darkroom_heldout``: cross-entropy against ``optimal_actions``,
          repeated along dim 1 of the action predictions.
        * every other env: cross-entropy against ``context_opt_actions``.
          For ``linear_bandit_train_lookahead_wt`` the targets are scaled by
          ``context_sum_rewards`` only when ``training`` is True; evaluation
          compares against the unscaled targets.
        """
        if env in ('linear_bandit_train_lookahead_pred_reward',
                   'darkroom_heldout_lookahead_pred_reward'):
            true_actions = batch['context_opt_actions'].reshape(-1, action_dim)
            true_rewards = batch['context_pred_rewards'].reshape(-1, 1)
            _, pred_rewards = model(batch)
            return loss_fn_reward(
                reward_at_actions(pred_rewards, true_actions), true_rewards)

        if env == 'linear_bandit_train_lookahead_pred_reward_opt_a':
            true_actions = batch['context_opt_actions'].reshape(-1, action_dim)
            true_rewards = batch['context_pred_rewards'].reshape(-1, 1)
            true_opt_a = batch['context_pred_opt_a'].reshape(-1, action_dim)
            pred_actions, pred_rewards, pred_opt_a = model(batch)
            return (loss_fn(pred_actions.reshape(-1, action_dim), true_actions)
                    + loss_fn_reward(reward_at_actions(pred_rewards, true_actions),
                                     true_rewards)
                    + loss_fn(pred_opt_a.reshape(-1, action_dim), true_opt_a))

        if env == 'darkroom_heldout':
            true_actions = batch['optimal_actions']
            pred_actions = model(batch)
            # One target per sample, repeated for every prediction along dim 1.
            true_actions = true_actions.unsqueeze(1).repeat(
                1, pred_actions.shape[1], 1)
            return loss_fn(pred_actions.reshape(-1, action_dim),
                           true_actions.reshape(-1, action_dim))

        true_actions = batch['context_opt_actions'].reshape(-1, action_dim)
        pred_actions = model(batch).reshape(-1, action_dim)
        if training and env == 'linear_bandit_train_lookahead_wt':
            # Weight the next-action targets by the forward reward sums.
            weights = batch['context_sum_rewards'].reshape(-1, action_dim)
            true_actions = true_actions * weights
        return loss_fn(pred_actions, true_actions)

    for epoch in trange(num_epochs):

        # ---- EVALUATION (no estimated optimal actions involved) ----
        start_time = time.time()
        # Switch off dropout (and any other train-only behavior) while
        # measuring the test loss.
        model.eval()
        epoch_test_loss = 0.0
        with torch.no_grad():
            for batch in test_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                # Summed batch loss divided by horizon; divided by the
                # dataset size below.
                epoch_test_loss += compute_loss(batch, training=False).item() / horizon

        test_loss.append(epoch_test_loss / len(test_dataset))
        end_time = time.time()
        # NOTE(review): the root logger is not configured in this file, so
        # logging.info output is dropped unless configured elsewhere.
        logging.info(f"Test loss: {test_loss[-1]}, Eval time: {end_time - start_time}")
        tb_writer.add_scalar("Loss/test", test_loss[-1], epoch)
        if (epoch + 1) % 50 == 0 or (env == 'linear_bandit' and (epoch + 1) % 20 == 0):
            printw(f"Test loss: {test_loss[-1]}, Eval time: {end_time - start_time}")

        # ---- TRAINING (targets use estimated optimal actions; pred-reward
        # envs also supervise predicted rewards) ----
        model.train()
        epoch_train_loss = 0.0
        start_time = time.time()

        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}

            # Clear gradients from the previous step for every env
            # (previously missing for darkroom_heldout_lookahead_pred_reward,
            # which silently accumulated gradients across batches).
            optimizer.zero_grad()
            loss = compute_loss(batch, training=True)
            loss.backward()
            optimizer.step()
            # The LR scheduler is stepped once per batch, not per epoch.
            scheduler.step()
            epoch_train_loss += loss.item() / horizon

        train_loss.append(epoch_train_loss / len(train_dataset))
        end_time = time.time()

        logging.info(f"Train loss: {train_loss[-1]}, Train time: {end_time - start_time}")
        tb_writer.add_scalar("Loss/train", train_loss[-1], epoch)
        if (epoch + 1) % 50 == 0 or (env == 'linear_bandit' and (epoch + 1) % 10 == 0):
            printw(f"Train loss: {train_loss[-1]}, Train time: {end_time - start_time}, lr: {scheduler.get_last_lr()}")

        # CHECKPOINTING
        if (epoch + 1) % 50 == 0 or (env == 'linear_bandit' and (epoch + 1) % 20 == 0):
            torch.save(model.state_dict(),
                       f'/media/external/subho/DPT/models/{filename}_epoch{epoch+1}.pt')

        # PLOTTING (log-scale curves, skipping the first epoch's values)
        if (epoch + 1) % 10 == 0:
            plt.yscale('log')
            plt.plot(train_loss[1:], label="Train Loss")
            plt.plot(test_loss[1:], label="Test Loss")
            plt.legend()
            plt.savefig(f"figs/loss/{filename}_train_loss.png")
            plt.clf()
    # Final weights, in addition to the periodic per-epoch checkpoints.
    torch.save(model.state_dict(), f'/media/external/subho/DPT/models/{filename}.pt')
    # close() flushes any pending TensorBoard events and releases the event
    # file; flush() alone left the writer open.
    tb_writer.close()
    print("Done.")
